import abc
import random
from typing import Sequence, Any, Tuple

from gym import spaces as spaces

from centralized_verification.MultiAgentAPEnv import MultiAgentSafetyEnv

# Unit impulse (dx, dy) applied by each discrete action; the list index is the action id.
# Negative dy corresponds to the action documented as "up".
directions = [
    (0, 0),   # 0: do nothing
    (0, -1),  # 1: up
    (1, 0),   # 2: right
    (0, 1),   # 3: down
    (-1, 0),  # 4: left
]


class ParticleMomentum(MultiAgentSafetyEnv, abc.ABC):
    """
    Two-agent particle world with momentum, tracked in relative coordinates.

    Each agent chooses one of five actions, used as an index into the module-level ``directions`` table:
    0 = Do nothing
    1 = Move up
    2 = Right
    3 = Down
    4 = Left

    The environment state is the 4-tuple
    (relative x, relative y, relative x momentum + 2, relative y momentum + 2),
    where "relative" means agent 0's quantity minus agent 1's (the impulses are combined that way in ``step``).
    Relative positions are indices in [0, 2 * world_size - 2]; both axes equal to world_size - 1 counts as a collision.
    Relative momentum is kept in [-2, 2] and stored shifted by +2, i.e. as an index in [0, 4].

    Both agents receive the same observation: the whole state when momentum is observable,
    otherwise only the two position components.
    """

    def __init__(self, world_size: int = 10, agents_observe_momentum: bool = False, randomize_starts: bool = False,
                 collision_reward=-30, terminate_on_collision: bool = False):
        """
        :param world_size: Side length of the world; fixes the number of relative offsets per axis.
        :param agents_observe_momentum: If true, observations include the two momentum components.
        :param randomize_starts: If true, ``initial_state`` draws a random relative start position.
        :param collision_reward: Reward given to both agents for an unsafe step (collision, crossing or wall hit).
        :param terminate_on_collision: If true, an unsafe step ends the episode.
        """
        self.world_size = world_size
        # Offsets from -(world_size - 1) to +(world_size - 1), stored as indices 0 .. 2 * world_size - 2.
        self.num_relative_distances = 2 * world_size - 1
        self.agents_observe_others_momentum = agents_observe_momentum
        self.randomize_starts = randomize_starts
        self.collision_cost = collision_reward
        self.terminate_on_collision = terminate_on_collision

    def agent_obs_spaces(self) -> Sequence[spaces.Space]:
        """Return the observation space of each agent (the same space object, listed twice)."""
        dims = [self.num_relative_distances, self.num_relative_distances]
        if self.agents_observe_others_momentum:
            dims += [5, 5]
        shared_space = spaces.MultiDiscrete(dims)
        return [shared_space, shared_space]

    def agent_actions_spaces(self) -> Sequence[spaces.Space]:
        """Return the action space of each agent: five discrete actions apiece."""
        shared_space = spaces.Discrete(5)
        return [shared_space, shared_space]

    def state_space(self) -> spaces.Space:
        """Return the space of full environment states (two positions, two shifted momenta)."""
        n = self.num_relative_distances
        return spaces.MultiDiscrete([n, n, 5, 5])

    def initial_state(self):
        """
        Build the starting state and the matching observations.

        :return: ``(state, observations)``; the momentum components start at 2, i.e. no momentum.
        """
        if not self.randomize_starts:
            state = (1, 1, 2, 2)
        else:
            # NOTE(review): random.sample draws without replacement, so the two coordinates always differ.
            # That rules out every diagonal start (including the collision and goal cells) - confirm this is intended.
            start_x, start_y = random.sample(range(self.num_relative_distances), 2)
            state = (start_x, start_y, 2, 2)

        return state, self.project_obs(state)

    def step(self, environment_state, joint_action: Sequence[Any]) -> Tuple[
        Any, Sequence[Any], Sequence[float], bool, bool]:
        """
        Advance the environment by one joint action.

        :param environment_state: Current state 4-tuple (see the class docstring).
        :param joint_action: The two agents' actions, each an index into ``directions``.
        :return: ``(next state, observations, rewards, done, safe)``. Both agents get the same reward:
            100 on reaching the goal cell, the configured collision reward on an unsafe step, -1 otherwise.
            ``safe`` is False when the step was a collision, a path crossing or a wall hit.
        """
        rel_x, rel_y, momentum_x, momentum_y = environment_state
        first_action, second_action = joint_action

        first_dx, first_dy = directions[first_action]
        second_dx, second_dy = directions[second_action]

        def advance(position, momentum, impulse):
            # Apply the impulse, keeping the shifted momentum inside [0, 4].
            momentum = max(min(momentum + impulse, 4), 0)
            position = position + (momentum - 2)
            if 0 <= position < self.num_relative_distances:
                return position, momentum, False
            # Ran off an end of the axis: clamp to the edge and drop all momentum.
            edge = 0 if position < 0 else self.num_relative_distances - 1
            return edge, 2, True

        # The relative impulse is agent 0's impulse minus agent 1's.
        next_x, next_momentum_x, hit_x = advance(rel_x, momentum_x, first_dx - second_dx)
        next_y, next_momentum_y, hit_y = advance(rel_y, momentum_y, first_dy - second_dy)

        mid = self.world_size - 1

        collided = next_x == mid and next_y == mid
        # Crossing paths: the relative offset is mirrored through the midpoint, starting at most one cell away.
        offsets = (-1, 0, 1)
        crossed = any(
            (rel_x, rel_y, next_x, next_y) == (mid - dx, mid - dy, mid + dx, mid + dy)
            for dx in offsets
            for dy in offsets
        )

        unsafe = collided or crossed or hit_x or hit_y

        next_state = (next_x, next_y, next_momentum_x, next_momentum_y)

        goal_index = self.num_relative_distances - 2
        reached_goal = (next_x, next_y) == (goal_index, goal_index)

        done = bool(reached_goal or (unsafe and self.terminate_on_collision))

        # Reaching the goal takes precedence over the unsafe penalty.
        reward = 100 if reached_goal else (self.collision_cost if unsafe else -1)

        return next_state, self.project_obs(next_state), [reward, reward], done, (not unsafe)

    def project_obs(self, state) -> Sequence[Any]:
        """Return the pair of per-agent observations for ``state`` (both entries are the same object)."""
        if not self.agents_observe_others_momentum:
            positions = state[:2]
            return (positions, positions)
        full = tuple(state)
        return (full, full)
